# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.seq2seq.python.ops.attention_wrapper."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import

import sys
import functools

import numpy as np

from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper as wrapper
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.platform import test
from tensorflow.python.util import nest

# pylint: enable=g-import-not-at-top


# for testing
# Short module-level aliases so the large golden-value literals in the tests
# below stay readable (e.g. `array(...)` instead of `np.array(...)`), and so
# printed session results can be copy/pasted back into the file verbatim.
AttentionWrapperState = wrapper.AttentionWrapperState  # pylint: disable=invalid-name
LSTMStateTuple = core_rnn_cell.LSTMStateTuple  # pylint: disable=invalid-name
BasicDecoderOutput = basic_decoder.BasicDecoderOutput  # pylint: disable=invalid-name
float32 = np.float32
int32 = np.int32
array = np.array


class AttentionWrapperTest(test.TestCase):

  def assertAllClose(self, *args, **kwargs):
    """assertAllClose with tolerances loosened for GPU nondeterminism."""
    # Force both tolerances to 1e-4 so small float drift between CPU and GPU
    # kernels does not flake the golden-value comparisons.
    kwargs.update(atol=1e-4, rtol=1e-4)
    return super(AttentionWrapperTest, self).assertAllClose(*args, **kwargs)

  def testAttentionWrapperState(self):
    """clone() replaces only the requested field, leaving the source intact."""
    field_count = len(wrapper.AttentionWrapperState._fields)  # pylint: disable=protected-access
    state = wrapper.AttentionWrapperState(*[None for _ in range(field_count)])
    cloned = state.clone(time=1)
    # The original namedtuple is immutable and must be unchanged.
    self.assertEqual(state.time, None)
    self.assertEqual(cloned.time, 1)

  def _testWithAttention(self,
                         create_attention_mechanism,
                         expected_final_output,
                         expected_final_state,
                         attention_mechanism_depth=3,
                         alignment_history=False,
                         expected_final_alignment_history=None,
                         attention_layer_size=6,
                         name=""):
    """Builds an AttentionWrapper decoder graph, runs it, checks golden values.

    Constructs an LSTM cell wrapped with the attention mechanism produced by
    `create_attention_mechanism`, decodes random (but fixed-shape) inputs with
    a TrainingHelper/BasicDecoder, and compares the session results against
    the caller-supplied expected constants. It also prints the actual results
    in copy/paste form so new golden values can be regenerated.

    Args:
      create_attention_mechanism: Callable accepting (num_units, memory,
        memory_sequence_length) that returns an attention mechanism.
      expected_final_output: A `BasicDecoderOutput` of expected numpy values.
      expected_final_state: An `AttentionWrapperState` of expected numpy
        values (with `alignment_history=()`).
      attention_mechanism_depth: `num_units` passed to the mechanism.
      alignment_history: Whether to record and check alignment history.
      expected_final_alignment_history: Expected stacked alignments, time
        major; only used when `alignment_history` is True.
      attention_layer_size: Size of the attention layer, or None to emit
        the raw context (whose depth is then `encoder_output_depth`).
      name: Label used in the copy/paste printouts.
    """
    encoder_sequence_length = [3, 2, 3, 1, 0]
    decoder_sequence_length = [2, 0, 1, 2, 3]
    batch_size = 5
    encoder_max_time = 8
    decoder_max_time = 4
    input_depth = 7
    encoder_output_depth = 10
    cell_depth = 9

    # With no attention layer, the wrapper's output is the raw context vector,
    # so the attention depth equals the encoder output depth.
    if attention_layer_size is not None:
      attention_depth = attention_layer_size
    else:
      attention_depth = encoder_output_depth

    decoder_inputs = np.random.randn(batch_size, decoder_max_time,
                                     input_depth).astype(np.float32)
    encoder_outputs = np.random.randn(batch_size, encoder_max_time,
                                      encoder_output_depth).astype(np.float32)

    attention_mechanism = create_attention_mechanism(
        num_units=attention_mechanism_depth,
        memory=encoder_outputs,
        memory_sequence_length=encoder_sequence_length)

    with self.test_session(use_gpu=True) as sess:
      # Fixed seed so variable initialization (and hence golden values) is
      # reproducible; note the expected constants depend on this seed.
      with vs.variable_scope(
          "root",
          initializer=init_ops.random_normal_initializer(stddev=0.01, seed=3)):
        cell = core_rnn_cell.LSTMCell(cell_depth)
        cell = wrapper.AttentionWrapper(
            cell,
            attention_mechanism,
            attention_layer_size=attention_layer_size,
            alignment_history=alignment_history)
        helper = helper_py.TrainingHelper(decoder_inputs,
                                          decoder_sequence_length)
        my_decoder = basic_decoder.BasicDecoder(
            cell=cell,
            helper=helper,
            initial_state=cell.zero_state(
                dtype=dtypes.float32, batch_size=batch_size))

        final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

      self.assertTrue(
          isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
      self.assertTrue(
          isinstance(final_state, wrapper.AttentionWrapperState))
      self.assertTrue(
          isinstance(final_state.cell_state, core_rnn_cell.LSTMStateTuple))

      # Static shape checks: the time dimension is dynamic (None).
      self.assertEqual((batch_size, None, attention_depth),
                       tuple(final_outputs.rnn_output.get_shape().as_list()))
      self.assertEqual((batch_size, None),
                       tuple(final_outputs.sample_id.get_shape().as_list()))

      self.assertEqual((batch_size, attention_depth),
                       tuple(final_state.attention.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.c.get_shape().as_list()))
      self.assertEqual((batch_size, cell_depth),
                       tuple(final_state.cell_state.h.get_shape().as_list()))

      if alignment_history:
        state_alignment_history = final_state.alignment_history.stack()
        # Remove the history from final_state for purposes of the
        # remainder of the tests.
        final_state = final_state._replace(alignment_history=())  # pylint: disable=protected-access
        self.assertEqual((None, batch_size, encoder_max_time),
                         tuple(state_alignment_history.get_shape().as_list()))
      else:
        state_alignment_history = ()

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          "final_outputs": final_outputs,
          "final_state": final_state,
          "state_alignment_history": state_alignment_history,
      })

      # Print actual values in a form that can be pasted back into the test
      # to regenerate the golden constants when behavior changes.
      print("Copy/paste (%s)\nexpected_final_output = " % name,
            sess_results["final_outputs"])
      sys.stdout.flush()
      print("Copy/paste (%s)\nexpected_final_state = " % name,
            sess_results["final_state"])
      sys.stdout.flush()
      print("Copy/paste (%s)\nexpected_final_alignment_history = " % name,
            np.asarray(sess_results["state_alignment_history"]))
      sys.stdout.flush()
      nest.map_structure(self.assertAllClose, expected_final_output,
                         sess_results["final_outputs"])
      nest.map_structure(self.assertAllClose, expected_final_state,
                         sess_results["final_state"])
      if alignment_history:  # by default, the wrapper emits attention as output
        self.assertAllClose(
            # outputs are batch major but the stacked TensorArray is time major
            sess_results["state_alignment_history"],
            expected_final_alignment_history)

  def testBahdanauNotNormalized(self):
    """Golden-value test: BahdanauAttention without weight normalization.

    The constants below were produced by the copy/paste printout of
    `_testWithAttention` under the fixed initializer seed; they are not
    derived analytically.
    """
    create_attention_mechanism = wrapper.BahdanauAttention

    expected_final_output = BasicDecoderOutput(
        rnn_output=array(
            [[[
                2.04633363e-03, 1.89259532e-03, 2.09550979e-03, -3.81628517e-03,
                -4.36160620e-03, -6.43933658e-03
            ], [
                2.41885195e-03, 2.02089013e-03, 2.05879519e-03, -3.85483308e-03,
                -3.51473060e-03, -6.14458136e-03
            ], [
                2.02294230e-03, 2.06955452e-03, 2.34797411e-03, -3.62816593e-03,
                -3.80352931e-03, -6.27150526e-03
            ]], [[
                4.89025004e-03, -1.97221269e-03, 3.34283570e-03,
                -2.79326970e-03, 3.63148772e-03, -4.79645561e-03
            ], [
                5.13446378e-03, -2.03941623e-03, 3.51774949e-03,
                -2.83448119e-03, 3.14159272e-03, -5.31486655e-03
            ], [
                5.20701287e-03, -2.21262546e-03, 3.58187454e-03,
                -2.85831164e-03, 3.20822699e-03, -5.20829484e-03
            ]], [[
                -1.34046993e-03, -9.99792013e-04, -2.11631414e-03,
                -1.85202830e-03, -5.26227616e-03, -9.08544939e-03
            ], [
                -1.35486713e-03, -1.04408595e-03, -1.96779310e-03,
                -1.80004584e-03, -5.61304903e-03, -9.34211537e-03
            ], [
                -1.12452905e-03, -7.68281636e-04, -1.99770415e-03,
                -1.88058324e-03, -5.01882844e-03, -9.32228006e-03
            ]], [[
                1.52967637e-03, -3.97213362e-03, -9.64699371e-04,
                8.51419638e-04, -1.29806029e-03, 6.56482670e-03
            ], [
                1.22562144e-03, -4.56351135e-03, -1.08190742e-03,
                8.27267300e-04, -2.10060296e-03, 6.43097097e-03
            ], [
                9.93521884e-04, -4.37386986e-03, -1.41534151e-03,
                6.44790183e-04, -2.16482091e-03, 6.68301852e-03
            ]], [[
                -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
                -1.56512906e-04, 9.63474595e-05
            ], [
                -1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
                6.56487318e-05, -1.48634164e-04, -1.84347919e-05
            ], [
                1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
                2.73401442e-04, -2.69805576e-04
            ]]],
            dtype=float32),
        sample_id=array(
            [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
            dtype=int32))

    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array(
                [[
                    -2.18977481e-02, -8.04181397e-03, -1.48273818e-03,
                    1.61075518e-02, -1.37986457e-02, -7.57964421e-03,
                    -8.28644261e-03, -1.18742418e-02, 1.78838037e-02
                ], [
                    1.74201727e-02, -1.41931782e-02, -3.88098788e-03,
                    3.19711640e-02, -3.54694054e-02, -2.14694049e-02,
                    -6.21706853e-03, -1.69323490e-03, -1.94494929e-02
                ], [
                    -1.14532551e-02, 8.77828151e-03, -1.62972715e-02,
                    -1.39963031e-02, 1.34832524e-02, -1.04488730e-02,
                    6.16201758e-03, -9.41041857e-03, -6.57599326e-03
                ], [
                    -4.74753827e-02, -1.19123599e-02, -7.40140676e-05,
                    4.10552323e-02, -1.36711076e-03, 2.11795494e-02,
                    -2.80460101e-02, -5.44509329e-02, -2.91906092e-02
                ], [
                    2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
                    5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
                    -5.05525898e-03, 2.19191350e-02, 1.62497871e-02
                ]],
                dtype=float32),
            h=array(
                [[
                    -1.09847616e-02, -3.97357112e-03, -7.54502777e-04,
                    7.91223347e-03, -7.02199014e-03, -3.80705344e-03,
                    -4.22102772e-03, -6.05491130e-03, 8.92073940e-03
                ], [
                    8.68115202e-03, -7.16950046e-03, -1.88387593e-03,
                    1.62680726e-02, -1.76830068e-02, -1.06620435e-02,
                    -3.07523785e-03, -8.46023730e-04, -9.99386702e-03
                ], [
                    -5.71225956e-03, 4.50055022e-03, -8.07653368e-03,
                    -6.94842264e-03, 6.75687613e-03, -5.12083014e-03,
                    3.06244940e-03, -4.61752573e-03, -3.23935854e-03
                ], [
                    -2.37231534e-02, -5.88526297e-03, -3.72226204e-05,
                    2.01789513e-02, -6.75848918e-04, 1.06686372e-02,
                    -1.42624676e-02, -2.69628745e-02, -1.45034352e-02
                ], [
                    1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
                    2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
                    -2.54477374e-03, 1.11598391e-02, 7.94144534e-03
                ]],
                dtype=float32)),
        attention=array(
            [[
                0.00202294, 0.00206955, 0.00234797, -0.00362817, -0.00380353,
                -0.00627151
            ], [
                0.00520701, -0.00221263, 0.00358187, -0.00285831, 0.00320823,
                -0.00520829
            ], [
                -0.00112453, -0.00076828, -0.0019977, -0.00188058, -0.00501883,
                -0.00932228
            ], [
                0.00099352, -0.00437387, -0.00141534, 0.00064479, -0.00216482,
                0.00668302
            ], [
                0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
                -0.00026981
            ]],
            dtype=float32),
        time=3,
        alignment_history=())

    # Time-major stacked alignments; rows for fully-masked (length-0) batch
    # entries stay uniform at 1/8.
    expected_final_alignment_history = [[[
        0.12586178, 0.12272788, 0.1271652, 0.12484902, 0.12484902, 0.12484902,
        0.12484902, 0.12484902
    ], [
        0.12612638, 0.12516938, 0.12478404, 0.12478404, 0.12478404, 0.12478404,
        0.12478404, 0.12478404
    ], [
        0.12595113, 0.12515794, 0.1255464, 0.1246689, 0.1246689, 0.1246689,
        0.1246689, 0.1246689
    ], [
        0.12492912, 0.12501013, 0.12501013, 0.12501013, 0.12501013, 0.12501013,
        0.12501013, 0.12501013
    ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[
        0.12586173, 0.12272781, 0.12716517, 0.12484905, 0.12484905, 0.12484905,
        0.12484905, 0.12484905
    ], [
        0.12612617, 0.1251694, 0.12478408, 0.12478408, 0.12478408, 0.12478408,
        0.12478408, 0.12478408
    ], [
        0.12595108, 0.12515777, 0.1255464, 0.12466895, 0.12466895, 0.12466895,
        0.12466895, 0.12466895
    ], [
        0.12492914, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012,
        0.12501012, 0.12501012
    ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]], [[
        0.12586181, 0.12272815, 0.12716556, 0.12484891, 0.12484891, 0.12484891,
        0.12484891, 0.12484891
    ], [
        0.12612608, 0.12516941, 0.12478409, 0.12478409, 0.12478409, 0.12478409,
        0.12478409, 0.12478409
    ], [
        0.12595116, 0.12515792, 0.12554643, 0.1246689, 0.1246689, 0.1246689,
        0.1246689, 0.1246689
    ], [
        0.1249292, 0.12501012, 0.12501012, 0.12501012, 0.12501012, 0.12501012,
        0.12501012, 0.12501012
    ], [0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125, 0.125]]]

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        alignment_history=True,
        expected_final_alignment_history=expected_final_alignment_history,
        name="testBahdanauNotNormalized")

  def testBahdanauNormalized(self):
    """Golden-value test: BahdanauAttention with weight normalization.

    The constants below were produced by the copy/paste printout of
    `_testWithAttention` under the fixed initializer seed.
    """
    create_attention_mechanism = functools.partial(
        wrapper.BahdanauAttention, normalize=True)

    expected_final_output = BasicDecoderOutput(
        rnn_output=array(
            [[[
                1.27064800e-02, 3.57783446e-03, 8.22613202e-03, -1.61504047e-03,
                -1.12555185e-02, -3.92740499e-03
            ], [
                1.30781950e-02, 3.70747922e-03, 8.18992872e-03, -1.65389013e-03,
                -1.04098395e-02, -3.63383139e-03
            ], [
                1.26833543e-02, 3.75790196e-03, 8.48123431e-03, -1.42690970e-03,
                -1.07016256e-02, -3.76088684e-03
            ]], [[
                6.88417302e-03, -2.04071682e-03, 4.17768257e-03,
                -4.51408979e-03, 4.90086433e-03, -6.85973791e-03
            ], [
                7.12782983e-03, -2.10783770e-03, 4.35227761e-03,
                -4.55496181e-03, 4.41066315e-03, -7.37757795e-03
            ], [
                7.20011396e-03, -2.28102156e-03, 4.41620918e-03,
                -4.57867794e-03, 4.47713351e-03, -7.27072079e-03
            ]], [[
                -2.20676698e-03, -1.43745833e-03, -1.99429039e-03,
                -1.44722988e-03, -7.45461835e-03, -9.80243273e-03
            ], [
                -2.22120387e-03, -1.48139545e-03, -1.84528576e-03,
                -1.39490096e-03, -7.80559657e-03, -1.00586927e-02
            ], [
                -1.99079141e-03, -1.20571791e-03, -1.87507609e-03,
                -1.47541985e-03, -7.21158786e-03, -1.00391749e-02
            ]], [[
                1.48755650e-03, -3.89118027e-03, -9.40889120e-04,
                8.36852356e-04, -1.28285377e-03, 6.41521579e-03
            ], [
                1.18351437e-03, -4.48258361e-03, -1.05809816e-03,
                8.12723883e-04, -2.08540238e-03, 6.28142804e-03
            ], [
                9.51444614e-04, -4.29300033e-03, -1.39154412e-03,
                6.30271854e-04, -2.14963360e-03, 6.53359853e-03
            ]], [[
                -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
                -1.56512906e-04, 9.63474595e-05
            ], [
                -1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
                6.56487318e-05, -1.48634164e-04, -1.84347919e-05
            ], [
                1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
                2.73401442e-04, -2.69805576e-04
            ]]],
            dtype=float32),
        sample_id=array(
            [[0, 0, 0], [0, 0, 0], [1, 3, 1], [5, 5, 5], [3, 3, 2]],
            dtype=int32))

    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array(
                [[
                    -2.19953191e-02, -7.81358499e-03, -1.42740645e-03,
                    1.62037201e-02, -1.38600282e-02, -7.60386931e-03,
                    -8.42390209e-03, -1.18884994e-02, 1.78821683e-02
                ], [
                    1.74096227e-02, -1.41773149e-02, -3.89175024e-03,
                    3.19635086e-02, -3.54669318e-02, -2.14924756e-02,
                    -6.20695669e-03, -1.73213519e-03, -1.94583312e-02
                ], [
                    -1.14590004e-02, 8.76899902e-03, -1.62825100e-02,
                    -1.39863417e-02, 1.34333782e-02, -1.04652103e-02,
                    6.13503950e-03, -9.39247012e-03, -6.57595927e-03
                ], [
                    -4.74739373e-02, -1.19136302e-02, -7.36713409e-05,
                    4.10547927e-02, -1.36768632e-03, 2.11772211e-02,
                    -2.80480143e-02, -5.44514954e-02, -2.91903671e-02
                ], [
                    2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
                    5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
                    -5.05525898e-03, 2.19191350e-02, 1.62497871e-02
                ]],
                dtype=float32),
            h=array(
                [[
                    -1.10325804e-02, -3.86056723e-03, -7.26287195e-04,
                    7.95945339e-03, -7.05253659e-03, -3.81913339e-03,
                    -4.29130904e-03, -6.06246945e-03, 8.91948957e-03
                ], [
                    8.67583323e-03, -7.16136536e-03, -1.88911252e-03,
                    1.62639488e-02, -1.76817775e-02, -1.06735229e-02,
                    -3.07015004e-03, -8.65494134e-04, -9.99815390e-03
                ], [
                    -5.71519835e-03, 4.49585915e-03, -8.06909613e-03,
                    -6.94347266e-03, 6.73189852e-03, -5.12895826e-03,
                    3.04909074e-03, -4.60868096e-03, -3.23936995e-03
                ], [
                    -2.37224363e-02, -5.88588836e-03, -3.70502457e-05,
                    2.01787297e-02, -6.76134136e-04, 1.06674768e-02,
                    -1.42634623e-02, -2.69631669e-02, -1.45033086e-02
                ], [
                    1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
                    2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
                    -2.54477374e-03, 1.11598391e-02, 7.94144534e-03
                ]],
                dtype=float32)),
        attention=array(
            [[
                0.01268335, 0.0037579, 0.00848123, -0.00142691, -0.01070163,
                -0.00376089
            ], [
                0.00720011, -0.00228102, 0.00441621, -0.00457868, 0.00447713,
                -0.00727072
            ], [
                -0.00199079, -0.00120572, -0.00187508, -0.00147542, -0.00721159,
                -0.01003917
            ], [
                0.00095144, -0.004293, -0.00139154, 0.00063027, -0.00214963,
                0.0065336
            ], [
                0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
                -0.00026981
            ]],
            dtype=float32),
        time=3,
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        name="testBahdanauNormalized")

  def testLuongNotNormalized(self):
    """Golden-value test: LuongAttention without scaling.

    Luong attention requires `num_units` to match the cell depth (9), hence
    `attention_mechanism_depth=9` below. Constants were produced by the
    copy/paste printout of `_testWithAttention`.
    """
    create_attention_mechanism = wrapper.LuongAttention

    expected_final_output = BasicDecoderOutput(
        rnn_output=array(
            [[[
                1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03,
                -4.17229906e-03, -6.65769773e-03
            ], [
                1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03,
                -3.15058464e-03, -6.34974428e-03
            ], [
                2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03,
                -3.96897132e-03, -6.12034509e-03
            ]], [[
                4.76492243e-03, -1.97180966e-03, 3.29327444e-03,
                -2.68205139e-03, 3.55229783e-03, -4.66645230e-03
            ], [
                5.24956919e-03, -2.00631656e-03, 3.53828911e-03,
                -2.96283513e-03, 3.20920302e-03, -5.43697737e-03
            ], [
                5.30424621e-03, -2.17913301e-03, 3.59509978e-03,
                -2.97106663e-03, 3.26450402e-03, -5.31189423e-03
            ]], [[
                -1.36440888e-03, -9.75572329e-04, -2.11284542e-03,
                -1.84616144e-03, -5.31351101e-03, -9.12462734e-03
            ], [
                -1.41863467e-03, -1.11081311e-03, -1.94056751e-03,
                -1.74311269e-03, -5.76282106e-03, -9.29267984e-03
            ], [
                -1.12129003e-03, -8.15156149e-04, -2.01535341e-03,
                -1.89556007e-03, -5.04226238e-03, -9.37188603e-03
            ]], [[
                1.55163277e-03, -4.01433324e-03, -9.77111282e-04,
                8.59013060e-04, -1.30598655e-03, 6.64281659e-03
            ], [
                1.26811734e-03, -4.64518648e-03, -1.10593368e-03,
                8.41954607e-04, -2.11594440e-03, 6.58190623e-03
            ], [
                1.02682540e-03, -4.43787826e-03, -1.43417739e-03,
                6.56281307e-04, -2.17684195e-03, 6.80128345e-03
            ]], [[
                -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
                -1.56512906e-04, 9.63474595e-05
            ], [
                -1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
                6.56487318e-05, -1.48634164e-04, -1.84347919e-05
            ], [
                1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
                2.73401442e-04, -2.69805576e-04
            ]]],
            dtype=float32),
        sample_id=array(
            [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
            dtype=int32))

    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array(
                [[
                    -2.18942575e-02, -8.05099495e-03, -1.48526859e-03,
                    1.61030665e-02, -1.37967104e-02, -7.57982396e-03,
                    -8.28088820e-03, -1.18743815e-02, 1.78839806e-02
                ], [
                    1.74203254e-02, -1.41929490e-02, -3.88103351e-03,
                    3.19709182e-02, -3.54691371e-02, -2.14697979e-02,
                    -6.21709181e-03, -1.69324467e-03, -1.94495786e-02
                ], [
                    -1.14536462e-02, 8.77809525e-03, -1.62965059e-02,
                    -1.39955431e-02, 1.34810507e-02, -1.04491040e-02,
                    6.16097450e-03, -9.40943789e-03, -6.57613343e-03
                ], [
                    -4.74765450e-02, -1.19113335e-02, -7.42897391e-05,
                    4.10555862e-02, -1.36665069e-03, 2.11814232e-02,
                    -2.80444007e-02, -5.44504896e-02, -2.91908123e-02
                ], [
                    2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
                    5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
                    -5.05525898e-03, 2.19191350e-02, 1.62497871e-02
                ]],
                dtype=float32),
            h=array(
                [[
                    -1.09830676e-02, -3.97811923e-03, -7.55793473e-04,
                    7.91002903e-03, -7.02103321e-03, -3.80714820e-03,
                    -4.21818346e-03, -6.05497835e-03, 8.92084371e-03
                ], [
                    8.68122280e-03, -7.16937613e-03, -1.88389909e-03,
                    1.62679367e-02, -1.76828820e-02, -1.06622437e-02,
                    -3.07524228e-03, -8.46030540e-04, -9.99389403e-03
                ], [
                    -5.71245840e-03, 4.50045895e-03, -8.07614625e-03,
                    -6.94804778e-03, 6.75577158e-03, -5.12094703e-03,
                    3.06193763e-03, -4.61703911e-03, -3.23943049e-03
                ], [
                    -2.37237271e-02, -5.88475820e-03, -3.73612711e-05,
                    2.01791357e-02, -6.75620860e-04, 1.06695695e-02,
                    -1.42616741e-02, -2.69626491e-02, -1.45035451e-02
                ], [
                    1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
                    2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
                    -2.54477374e-03, 1.11598391e-02, 7.94144534e-03
                ]],
                dtype=float32)),
        attention=array(
            [[
                0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897,
                -0.00612035
            ], [
                0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645,
                -0.00531189
            ], [
                -0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226,
                -0.00937189
            ], [
                0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684,
                0.00680128
            ], [
                0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
                -0.00026981
            ]],
            dtype=float32),
        time=3,
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        name="testLuongNotNormalized")

  def testLuongScaled(self):
    """Golden-value test: LuongAttention with scale=True.

    NOTE(review): the expected constants here are identical to
    testLuongNotNormalized — with the scale variable initialized to 1 the two
    graphs apparently produce the same values under this seed; confirm the
    scale variable actually participates if these tests diverge.
    """
    create_attention_mechanism = functools.partial(
        wrapper.LuongAttention, scale=True)

    expected_final_output = BasicDecoderOutput(
        rnn_output=array(
            [[[
                1.74922391e-03, 1.85935036e-03, 1.90880906e-03, -3.96941090e-03,
                -4.17229906e-03, -6.65769773e-03
            ], [
                1.99638237e-03, 1.91135216e-03, 1.73234346e-03, -4.00905171e-03,
                -3.15058464e-03, -6.34974428e-03
            ], [
                2.08854163e-03, 2.13832827e-03, 2.49780947e-03, -3.52849509e-03,
                -3.96897132e-03, -6.12034509e-03
            ]], [[
                4.76492243e-03, -1.97180966e-03, 3.29327444e-03,
                -2.68205139e-03, 3.55229783e-03, -4.66645230e-03
            ], [
                5.24956919e-03, -2.00631656e-03, 3.53828911e-03,
                -2.96283513e-03, 3.20920302e-03, -5.43697737e-03
            ], [
                5.30424621e-03, -2.17913301e-03, 3.59509978e-03,
                -2.97106663e-03, 3.26450402e-03, -5.31189423e-03
            ]], [[
                -1.36440888e-03, -9.75572329e-04, -2.11284542e-03,
                -1.84616144e-03, -5.31351101e-03, -9.12462734e-03
            ], [
                -1.41863467e-03, -1.11081311e-03, -1.94056751e-03,
                -1.74311269e-03, -5.76282106e-03, -9.29267984e-03
            ], [
                -1.12129003e-03, -8.15156149e-04, -2.01535341e-03,
                -1.89556007e-03, -5.04226238e-03, -9.37188603e-03
            ]], [[
                1.55163277e-03, -4.01433324e-03, -9.77111282e-04,
                8.59013060e-04, -1.30598655e-03, 6.64281659e-03
            ], [
                1.26811734e-03, -4.64518648e-03, -1.10593368e-03,
                8.41954607e-04, -2.11594440e-03, 6.58190623e-03
            ], [
                1.02682540e-03, -4.43787826e-03, -1.43417739e-03,
                6.56281307e-04, -2.17684195e-03, 6.80128345e-03
            ]], [[
                -3.78854020e-04, 5.62231544e-05, 1.06837302e-04, 1.87137164e-04,
                -1.56512906e-04, 9.63474595e-05
            ], [
                -1.04306288e-04, -1.37411975e-04, 2.82689070e-05,
                6.56487318e-05, -1.48634164e-04, -1.84347919e-05
            ], [
                1.24452345e-04, 2.20821079e-04, 4.07114130e-04, 2.18028668e-04,
                2.73401442e-04, -2.69805576e-04
            ]]],
            dtype=float32),
        sample_id=array(
            [[2, 0, 2], [0, 0, 0], [1, 1, 1], [5, 5, 5], [3, 3, 2]],
            dtype=int32))

    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array(
                [[
                    -2.18942575e-02, -8.05099495e-03, -1.48526859e-03,
                    1.61030665e-02, -1.37967104e-02, -7.57982396e-03,
                    -8.28088820e-03, -1.18743815e-02, 1.78839806e-02
                ], [
                    1.74203254e-02, -1.41929490e-02, -3.88103351e-03,
                    3.19709182e-02, -3.54691371e-02, -2.14697979e-02,
                    -6.21709181e-03, -1.69324467e-03, -1.94495786e-02
                ], [
                    -1.14536462e-02, 8.77809525e-03, -1.62965059e-02,
                    -1.39955431e-02, 1.34810507e-02, -1.04491040e-02,
                    6.16097450e-03, -9.40943789e-03, -6.57613343e-03
                ], [
                    -4.74765450e-02, -1.19113335e-02, -7.42897391e-05,
                    4.10555862e-02, -1.36665069e-03, 2.11814232e-02,
                    -2.80444007e-02, -5.44504896e-02, -2.91908123e-02
                ], [
                    2.25644894e-02, -1.40382675e-03, 1.92396250e-02,
                    5.49034867e-03, -1.27930511e-02, -3.15603940e-03,
                    -5.05525898e-03, 2.19191350e-02, 1.62497871e-02
                ]],
                dtype=float32),
            h=array(
                [[
                    -1.09830676e-02, -3.97811923e-03, -7.55793473e-04,
                    7.91002903e-03, -7.02103321e-03, -3.80714820e-03,
                    -4.21818346e-03, -6.05497835e-03, 8.92084371e-03
                ], [
                    8.68122280e-03, -7.16937613e-03, -1.88389909e-03,
                    1.62679367e-02, -1.76828820e-02, -1.06622437e-02,
                    -3.07524228e-03, -8.46030540e-04, -9.99389403e-03
                ], [
                    -5.71245840e-03, 4.50045895e-03, -8.07614625e-03,
                    -6.94804778e-03, 6.75577158e-03, -5.12094703e-03,
                    3.06193763e-03, -4.61703911e-03, -3.23943049e-03
                ], [
                    -2.37237271e-02, -5.88475820e-03, -3.73612711e-05,
                    2.01791357e-02, -6.75620860e-04, 1.06695695e-02,
                    -1.42616741e-02, -2.69626491e-02, -1.45035451e-02
                ], [
                    1.12585640e-02, -6.92534202e-04, 9.88917705e-03,
                    2.75237625e-03, -6.56115822e-03, -1.57997780e-03,
                    -2.54477374e-03, 1.11598391e-02, 7.94144534e-03
                ]],
                dtype=float32)),
        attention=array(
            [[
                0.00208854, 0.00213833, 0.00249781, -0.0035285, -0.00396897,
                -0.00612035
            ], [
                0.00530425, -0.00217913, 0.0035951, -0.00297107, 0.0032645,
                -0.00531189
            ], [
                -0.00112129, -0.00081516, -0.00201535, -0.00189556, -0.00504226,
                -0.00937189
            ], [
                0.00102683, -0.00443788, -0.00143418, 0.00065628, -0.00217684,
                0.00680128
            ], [
                0.00012445, 0.00022082, 0.00040711, 0.00021803, 0.0002734,
                -0.00026981
            ]],
            dtype=float32),
        time=3,
        alignment_history=())

    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_mechanism_depth=9,
        name="testLuongScaled")

  def testNotUseAttentionLayer(self):
    """Regression test for AttentionWrapper with attention_layer_size=None.

    Runs the shared attention-decoder harness (`self._testWithAttention`)
    with a Bahdanau attention mechanism and no attention layer, comparing
    the decoder outputs and the final wrapper state against the golden
    values captured below.
    """
    create_attention_mechanism = wrapper.BahdanauAttention

    # Golden decoder outputs after 3 decoding steps: per-step rnn_output
    # activations and the sampled ids for each batch entry.
    # NOTE(review): the all-zero rows in the last batch entry presumably
    # correspond to a zero-length / fully-masked sequence -- confirm against
    # the harness's sequence lengths.
    expected_final_output = BasicDecoderOutput(
        rnn_output=array(
            [[[
                -0.24223405, -0.07791166, 0.15451428, 0.24738294, 0.30900395,
                -0.24685201, 0.04992372, 0.18749543, -0.15878429, -0.13678923
            ], [
                -0.2422339, -0.07791159, 0.15451418, 0.24738279, 0.30900383,
                -0.24685188, 0.04992369, 0.18749531, -0.15878411, -0.13678911
            ], [
                -0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475,
                -0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965
            ]], [[
                0.40035266, 0.12299616, -0.06085059, -0.09197108, 0.11368551,
                -0.15302914, 0.00566157, -0.26885766, 0.08546552, 0.18886778
            ], [
                0.40035242, 0.12299603, -0.06085056, -0.09197091, 0.11368536,
                -0.15302882, 0.0056615, -0.26885763, 0.08546554, 0.18886763
            ], [
                0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532,
                -0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761
            ]], [[
                -0.4311333, 0.07519469, -0.01551808, 0.1913045, -0.02693807,
                -0.21668895, -0.02155721, 0.0013397, 0.21180844, 0.25578707
            ], [
                -0.43113309, 0.07519454, -0.01551818, 0.19130446, -0.0269379,
                -0.21668854, -0.021557, 0.00133975, 0.21180828, 0.25578681
            ], [
                -0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798,
                -0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696
            ]], [[
                0.07059932, 0.16451572, 0.01174669, 0.04646531, 0.1427598,
                0.0794456, -0.10852993, 0.15306188, 0.02151393, -0.05590061
            ], [
                0.07059933, 0.16451576, 0.01174669, 0.04646532, 0.14275983,
                0.07944562, -0.10852996, 0.15306193, 0.02151394, -0.05590062
            ], [
                0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599,
                0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065
            ]], [[0., 0., 0., 0., 0., 0., 0., 0., 0.,
                  0.], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                 [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]],
            dtype=float32),
        sample_id=array(
            [[4, 4, 4], [0, 0, 0], [9, 9, 9], [1, 1, 1], [0, 0, 0]],
            dtype=int32))

    # Golden final AttentionWrapperState: the LSTM (c, h) cell state, the
    # final per-batch attention vector, the final time step (3), and an
    # empty alignment history (history tracking not enabled in this test).
    expected_final_state = AttentionWrapperState(
        cell_state=LSTMStateTuple(
            c=array(
                [[
                    -0.0181195, -0.01675365, -0.00510353, 0.01559796,
                    -0.01251448, -0.00437002, -0.01243257, -0.01720199,
                    0.02274928
                ], [
                    0.01259979, -0.00839985, -0.00374037, 0.03136262,
                    -0.03486227, -0.02466441, -0.00496157, -0.00461032,
                    -0.02098336
                ], [
                    -0.00781067, 0.00315682, -0.0138283, -0.01149793,
                    0.00485562, -0.01343193, 0.0085915, -0.00632846, -0.01052086
                ], [
                    -0.04184828, -0.01223641, 0.0009445, 0.03911434, 0.0043249,
                    0.02220661, -0.03006243, -0.05418363, -0.02615385
                ], [
                    0.02282745, -0.00143833, 0.01918138, 0.00545033,
                    -0.01258384, -0.00303765, -0.00511231, 0.02166323,
                    0.01638841
                ]],
                dtype=float32),
            h=array(
                [[
                    -0.00910065, -0.00827571, -0.00259689, 0.00764857,
                    -0.00635579, -0.00218579, -0.00633918, -0.00875511,
                    0.01134532
                ], [
                    0.00626597, -0.004241, -0.00181303, 0.01597157, -0.0173375,
                    -0.01224921, -0.00244522, -0.00231299, -0.0107822
                ], [
                    -0.00391383, 0.00162017, -0.00682621, -0.00570264,
                    0.00244099, -0.00659772, 0.00426475, -0.00309861,
                    -0.00520028
                ], [
                    -0.02087484, -0.00603306, 0.00047561, 0.01920062,
                    0.00213875, 0.01115329, -0.0152659, -0.02687523, -0.01297523
                ], [
                    0.01138975, -0.00070959, 0.00986007, 0.0027323, -0.00645386,
                    -0.00152054, -0.00257339, 0.01103063, 0.00800891
                ]],
                dtype=float32)),
        attention=array(
            [[
                -0.2422343, -0.07791215, 0.15451413, 0.24738336, 0.30900475,
                -0.2468522, 0.04992349, 0.18749571, -0.158785, -0.13678965
            ], [
                0.40035242, 0.122996, -0.06085056, -0.09197087, 0.11368532,
                -0.1530287, 0.00566146, -0.26885769, 0.08546556, 0.18886761
            ], [
                -0.43113324, 0.07519463, -0.01551815, 0.1913045, -0.02693798,
                -0.21668874, -0.02155712, 0.00133973, 0.21180835, 0.25578696
            ], [
                0.07059937, 0.16451585, 0.0117467, 0.04646534, 0.1427599,
                0.07944567, -0.10853001, 0.153062, 0.02151395, -0.05590065
            ], [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],
            dtype=float32),
        time=3,
        alignment_history=())

    # Delegate to the shared harness; attention_layer_size=None exercises
    # the code path where no attention (projection) layer is created.
    self._testWithAttention(
        create_attention_mechanism,
        expected_final_output,
        expected_final_state,
        attention_layer_size=None,
        name="testNotUseAttentionLayer")


# Run the test suite when this module is executed as a script.
if __name__ == "__main__":
  test.main()
